//// THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF
//// ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO
//// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A
//// PARTICULAR PURPOSE.
////
//// Copyright (c) Microsoft Corporation. All rights reserved

#pragma once

namespace Microsoft {
	namespace Media {
		namespace FragmentedMediaSink {

			template <class TBase = IMFAttributes>
			class CBaseAttributes : public TBase
			{
			protected:
				// This version of the constructor does not initialize the 
				// attribute store. The derived class must call Initialize() in 
				// its own constructor.
				CBaseAttributes()
				{
				}

				// This version of the constructor initializes the attribute 
				// store, but the derived class must pass an HRESULT parameter 
				// to the constructor.

				CBaseAttributes(HRESULT &hr, UINT32 cInitialSize = 0)
				{
					hr = Initialize(cInitialSize);
				}

				// The next version of the constructor uses a caller-provided 
				// implementation of IMFAttributes.

				// (Sometimes you want to delegate IMFAttributes calls to some 
				// other object that implements IMFAttributes, rather than using 
				// MFCreateAttributes.)

				CBaseAttributes(HRESULT &hr, IUnknown *pUnk)
				{
					hr = Initialize(pUnk);
				}

				virtual ~CBaseAttributes()
				{
				}

				// Initializes the object by creating the standard Media Foundation attribute store.
				HRESULT Initialize(UINT32 cInitialSize = 0)
				{
					if (_spAttributes.Get() == nullptr)
					{
						return MFCreateAttributes(&_spAttributes, cInitialSize);
					}
					else
					{
						return S_OK;
					}
				}

				// Initializes this object from a caller-provided attribute store.
				// pUnk: Pointer to an object that exposes IMFAttributes.
				HRESULT Initialize(IUnknown *pUnk)
				{
					if (_spAttributes)
					{
						_spAttributes->Release();
						_spAttributes = NULL;
					}


					return pUnk->QueryInterface(IID_PPV_ARGS(&_spAttributes));
				}

			public:

				// IMFAttributes methods

				STDMETHODIMP GetItem(REFGUID guidKey, PROPVARIANT *pValue)
				{
					assert(_spAttributes);
					return _spAttributes->GetItem(guidKey, pValue);
				}

				STDMETHODIMP GetItemType(REFGUID guidKey, MF_ATTRIBUTE_TYPE *pType)
				{
					assert(_spAttributes);
					return _spAttributes->GetItemType(guidKey, pType);
				}

				STDMETHODIMP CompareItem(REFGUID guidKey, REFPROPVARIANT Value, BOOL *pbResult)
				{
					assert(_spAttributes);
					return _spAttributes->CompareItem(guidKey, Value, pbResult);
				}

				STDMETHODIMP Compare(
					IMFAttributes *pTheirs,
					MF_ATTRIBUTES_MATCH_TYPE MatchType,
					BOOL *pbResult
					)
				{
					assert(_spAttributes);
					return _spAttributes->Compare(pTheirs, MatchType, pbResult);
				}

				STDMETHODIMP GetUINT32(REFGUID guidKey, UINT32 *punValue)
				{
					assert(_spAttributes);
					return _spAttributes->GetUINT32(guidKey, punValue);
				}

				STDMETHODIMP GetUINT64(REFGUID guidKey, UINT64 *punValue)
				{
					assert(_spAttributes);
					return _spAttributes->GetUINT64(guidKey, punValue);
				}

				STDMETHODIMP GetDouble(REFGUID guidKey, double *pfValue)
				{
					assert(_spAttributes);
					return _spAttributes->GetDouble(guidKey, pfValue);
				}

				STDMETHODIMP GetGUID(REFGUID guidKey, GUID *pguidValue)
				{
					assert(_spAttributes);
					return _spAttributes->GetGUID(guidKey, pguidValue);
				}

				STDMETHODIMP GetStringLength(REFGUID guidKey, UINT32 *pcchLength)
				{
					assert(_spAttributes);
					return _spAttributes->GetStringLength(guidKey, pcchLength);
				}

				STDMETHODIMP GetString(REFGUID guidKey, LPWSTR pwszValue, UINT32 cchBufSize, UINT32 *pcchLength)
				{
					assert(_spAttributes);
					return _spAttributes->GetString(guidKey, pwszValue, cchBufSize, pcchLength);
				}

				STDMETHODIMP GetAllocatedString(REFGUID guidKey, LPWSTR *ppwszValue, UINT32 *pcchLength)
				{
					assert(_spAttributes);
					return _spAttributes->GetAllocatedString(guidKey, ppwszValue, pcchLength);
				}

				STDMETHODIMP GetBlobSize(REFGUID guidKey, UINT32 *pcbBlobSize)
				{
					assert(_spAttributes);
					return _spAttributes->GetBlobSize(guidKey, pcbBlobSize);
				}

				STDMETHODIMP GetBlob(REFGUID guidKey, UINT8 *pBuf, UINT32 cbBufSize, UINT32 *pcbBlobSize)
				{
					assert(_spAttributes);
					return _spAttributes->GetBlob(guidKey, pBuf, cbBufSize, pcbBlobSize);
				}

				STDMETHODIMP GetAllocatedBlob(REFGUID guidKey, UINT8 **ppBuf, UINT32 *pcbSize)
				{
					assert(_spAttributes);
					return _spAttributes->GetAllocatedBlob(guidKey, ppBuf, pcbSize);
				}

				STDMETHODIMP GetUnknown(REFGUID guidKey, REFIID riid, LPVOID *ppv)
				{
					assert(_spAttributes);
					return _spAttributes->GetUnknown(guidKey, riid, ppv);
				}

				STDMETHODIMP SetItem(REFGUID guidKey, REFPROPVARIANT Value)
				{
					assert(_spAttributes);
					return _spAttributes->SetItem(guidKey, Value);
				}

				STDMETHODIMP DeleteItem(REFGUID guidKey)
				{
					assert(_spAttributes);
					return _spAttributes->DeleteItem(guidKey);
				}

				STDMETHODIMP DeleteAllItems()
				{
					assert(_spAttributes);
					return _spAttributes->DeleteAllItems();
				}

				STDMETHODIMP SetUINT32(REFGUID guidKey, UINT32 unValue)
				{
					assert(_spAttributes);
					return _spAttributes->SetUINT32(guidKey, unValue);
				}

				STDMETHODIMP SetUINT64(REFGUID guidKey, UINT64 unValue)
				{
					assert(_spAttributes);
					return _spAttributes->SetUINT64(guidKey, unValue);
				}

				STDMETHODIMP SetDouble(REFGUID guidKey, double fValue)
				{
					assert(_spAttributes);
					return _spAttributes->SetDouble(guidKey, fValue);
				}

				STDMETHODIMP SetGUID(REFGUID guidKey, REFGUID guidValue)
				{
					assert(_spAttributes);
					return _spAttributes->SetGUID(guidKey, guidValue);
				}

				STDMETHODIMP SetString(REFGUID guidKey, LPCWSTR wszValue)
				{
					assert(_spAttributes);
					return _spAttributes->SetString(guidKey, wszValue);
				}

				STDMETHODIMP SetBlob(REFGUID guidKey, const UINT8 *pBuf, UINT32 cbBufSize)
				{
					assert(_spAttributes);
					return _spAttributes->SetBlob(guidKey, pBuf, cbBufSize);
				}

				STDMETHODIMP SetUnknown(REFGUID guidKey, IUnknown *pUnknown)
				{
					assert(_spAttributes);
					return _spAttributes->SetUnknown(guidKey, pUnknown);
				}

				STDMETHODIMP LockStore()
				{
					assert(_spAttributes);
					return _spAttributes->LockStore();
				}

				STDMETHODIMP UnlockStore()
				{
					assert(_spAttributes);
					return _spAttributes->UnlockStore();
				}

				STDMETHODIMP GetCount(UINT32 *pcItems)
				{
					assert(_spAttributes);
					return _spAttributes->GetCount(pcItems);
				}

				STDMETHODIMP GetItemByIndex(UINT32 unIndex, GUID *pguidKey, PROPVARIANT *pValue)
				{
					assert(_spAttributes);
					return _spAttributes->GetItemByIndex(unIndex, pguidKey, pValue);
				}

				STDMETHODIMP CopyAllItems(IMFAttributes *pDest)
				{
					assert(_spAttributes);
					return _spAttributes->CopyAllItems(pDest);
				}

				// Helper functions

				HRESULT SerializeToStream(DWORD dwOptions, IStream *pStm)
					// dwOptions: Flags from MF_ATTRIBUTE_SERIALIZE_OPTIONS
				{
					assert(_spAttributes);
					return MFSerializeAttributesToStream(_spAttributes.Get(), dwOptions, pStm);
				}

				HRESULT DeserializeFromStream(DWORD dwOptions, IStream *pStm)
				{
					assert(_spAttributes);
					return MFDeserializeAttributesFromStream(_spAttributes.Get(), dwOptions, pStm);
				}

				// SerializeToBlob: Stores the attributes in a byte array. 
				// 
				// ppBuf: Receives a pointer to the byte array. 
				// pcbSize: Receives the size of the byte array.
				//
				// The caller must free the array using CoTaskMemFree.
				HRESULT SerializeToBlob(UINT8 **ppBuffer, UINT *pcbSize)
				{
					assert(_spAttributes);

					if (ppBuffer == NULL)
					{
						return E_POINTER;
					}
					if (pcbSize == NULL)
					{
						return E_POINTER;
					}

					HRESULT hr = S_OK;
					UINT32 cbSize = 0;
					BYTE *pBuffer = NULL;

					CHECK_HR(hr = MFGetAttributesAsBlobSize(_spAttributes.Get(), &cbSize));

					pBuffer = (BYTE*)CoTaskMemAlloc(cbSize);
					if (pBuffer == NULL)
					{
						CHECK_HR(hr = E_OUTOFMEMORY);
					}

					CHECK_HR(hr = MFGetAttributesAsBlob(_spAttributes.Get(), pBuffer, cbSize));

					*ppBuffer = pBuffer;
					*pcbSize = cbSize;

				done:
					if (FAILED(hr))
					{
						*ppBuffer = NULL;
						*pcbSize = 0;
						CoTaskMemFree(pBuffer);
					}
					return hr;
				}

				HRESULT DeserializeFromBlob(const UINT8 *pBuffer, UINT cbSize)
				{
					assert(_spAttributes);
					return MFInitAttributesFromBlob(_spAttributes.Get(), pBuffer, cbSize);
				}

				HRESULT GetRatio(REFGUID guidKey, UINT32 *pnNumerator, UINT32 *punDenominator)
				{
					assert(_spAttributes);
					return MFGetAttributeRatio(_spAttributes.Get(), guidKey, pnNumerator, punDenominator);
				}

				HRESULT SetRatio(REFGUID guidKey, UINT32 unNumerator, UINT32 unDenominator)
				{
					assert(_spAttributes);
					return MFSetAttributeRatio(_spAttributes.Get(), guidKey, unNumerator, unDenominator);
				}

				// Gets an attribute whose value represents the size of something (eg a video frame).
				HRESULT GetSize(REFGUID guidKey, UINT32 *punWidth, UINT32 *punHeight)
				{
					assert(_spAttributes);
					return MFGetAttributeSize(_spAttributes.Get(), guidKey, punWidth, punHeight);
				}

				// Sets an attribute whose value represents the size of something (eg a video frame).
				HRESULT SetSize(REFGUID guidKey, UINT32 unWidth, UINT32 unHeight)
				{
					assert(_spAttributes);
					return MFSetAttributeSize(_spAttributes.Get(), guidKey, unWidth, unHeight);
				}

			protected:
				ComPtr<IMFAttributes> _spAttributes;
			};

		}
	}
}